#!/usr/bin/env python3
"""
DMLC submission script by ssh

One needs to make sure all slave machines are ssh-able.
"""
from __future__ import absolute_import

import logging
import os
import subprocess
from multiprocessing import Pool
from threading import Thread

from . import tracker


def sync_dir(local_dir, slave_node, slave_dir):
    """
    sync the working directory from root node into slave node

    Parameters
    ----------
    local_dir : str
        directory on this machine to copy from
    slave_node : tuple
        (host, port) pair; the port is handed to ssh via ``-p``
    slave_dir : str
        destination directory on the slave node

    Raises
    ------
    subprocess.CalledProcessError
        if rsync exits with a non-zero status
    """
    host, port = slave_node[0], slave_node[1]
    remote = host + ":" + slave_dir
    logging.info("rsync %s -> %s", local_dir, remote)
    # Pass an argument vector instead of a shell string so that paths
    # containing spaces or shell metacharacters reach rsync unmodified.
    # rsync itself splits the --rsh value into the ssh command line.
    cmd = [
        "rsync",
        "-az",
        "--rsh=ssh -o StrictHostKeyChecking=no -p %s" % port,
        local_dir,
        remote,
    ]
    subprocess.check_call(cmd)


def get_env(pass_envs):
    """
    build the string of shell ``export`` statements for a remote job

    A fixed set of variables is copied from the current process
    environment (only those that are set), followed by every entry of
    ``pass_envs``.  Each statement has the form ``export KEY=VALUE;`` and
    the statements are joined by single spaces.  Values are inserted
    verbatim, without any shell quoting.
    """
    # variables forwarded from the local environment when present
    forwarded = (
        "OMP_NUM_THREADS",
        "KMP_AFFINITY",
        "LD_LIBRARY_PATH",
        "AWS_ACCESS_KEY_ID",
        "AWS_SECRET_ACCESS_KEY",
        "DMLC_INTERFACE",
    )
    template = "export {}={};"
    exports = []
    for name in forwarded:
        value = os.getenv(name)
        if value is None:
            continue
        exports.append(template.format(name, value))
    # variables explicitly passed in by the caller
    exports.extend(
        template.format(str(name), str(value))
        for name, value in pass_envs.items()
    )
    return " ".join(exports)


def _parse_hosts(lines):
    """
    turn host-file lines into a list of (host, port) string pairs

    Blank lines are skipped.  Anything from ``slots=`` onwards is dropped
    so that MPI-style host files are accepted as well.  An entry of the
    form ``host:port`` supplies the ssh port; otherwise port "22" is used.
    """
    hosts = []
    for line in lines:
        entry = line.strip()
        if not entry:
            continue
        # mpi host file form "ip slots=??": keep only the address part
        slots_at = entry.find("slots=")
        if slots_at != -1:
            entry = entry[:slots_at].strip()
        # addresses of the form ip:port
        host, sep, port = entry.partition(":")
        hosts.append((host, port if sep else "22"))
    return hosts


def submit(args):
    """
    submit the job: start servers and workers on remote hosts through ssh

    The hosts are read from ``args.host_file``.  A launcher callback is
    passed to ``tracker.submit`` as ``fun_submit``; when called, it
    optionally rsyncs the current directory to every host (if
    ``args.sync_dst_dir`` is set) and then starts ``args.command`` on the
    hosts in round-robin order, servers first.

    Parameters
    ----------
    args : object
        must provide ``host_file``, ``sync_dst_dir``, ``command`` (list of
        str), ``num_workers``, ``num_servers`` and ``host_ip``

    Raises
    ------
    AssertionError
        if ``args.host_file`` is None or the file has no lines
    ValueError
        if the host file contains no host entries (only blank lines)
    """
    assert args.host_file is not None
    with open(args.host_file) as f:
        lines = f.readlines()
    assert len(lines) > 0
    # hosts contains the pairs (ip, port)
    hosts = _parse_hosts(lines)
    if not hosts:
        # Without this check the failure only shows up later as a
        # ZeroDivisionError / invalid pool size inside the launcher.
        raise ValueError("no hosts found in host file %s" % args.host_file)

    def ssh_submit(nworker, nserver, pass_envs):
        """
        customized submit script

        Launches ``nserver`` server processes followed by ``nworker``
        worker processes.  ``pass_envs`` is updated in place with
        DMLC_ROLE and DMLC_NODE_HOST before each launch.
        """

        # thread func to run the job
        def run(prog):
            subprocess.check_call(prog, shell=True)

        # sync programs if necessary
        local_dir = os.getcwd() + "/"
        working_dir = local_dir
        if args.sync_dst_dir is not None and args.sync_dst_dir != "None":
            working_dir = args.sync_dst_dir
            pool = Pool(processes=len(hosts))
            pending = [
                (h, pool.apply_async(sync_dir, args=(local_dir, h, working_dir)))
                for h in hosts
            ]
            pool.close()
            pool.join()
            # apply_async keeps a worker's exception inside its result
            # object; fetch each result so a failed rsync is reported
            # instead of vanishing.  The launch still proceeds, as before.
            for host, result in pending:
                try:
                    result.get()
                except Exception:
                    logging.exception(
                        "failed to sync %s to %s:%s",
                        local_dir,
                        host[0],
                        working_dir,
                    )

        # launch jobs
        command = " ".join(args.command)
        for i in range(nworker + nserver):
            pass_envs["DMLC_ROLE"] = "server" if i < nserver else "worker"
            (node, port) = hosts[i % len(hosts)]
            pass_envs["DMLC_NODE_HOST"] = node
            prog = get_env(pass_envs) + " cd " + working_dir + "; " + command
            prog = (
                "ssh -o StrictHostKeyChecking=no "
                + node
                + " -p "
                + port
                + " '"
                + prog
                + "'"
            )
            thread = Thread(target=run, args=(prog,))
            # daemon threads do not keep the interpreter alive on exit
            thread.daemon = True
            thread.start()

        # NOTE(review): returning the callback itself is kept from the
        # original code; confirm whether the tracker uses this value.
        return ssh_submit

    tracker.submit(
        args.num_workers,
        args.num_servers,
        fun_submit=ssh_submit,
        pscmd=(" ".join(args.command)),
        hostIP=args.host_ip,
    )
